{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "import gym\n",
    "import random\n",
    "import numpy as np\n",
    "import time\n",
    "from gym.envs.registration import register\n",
    "from IPython.display import clear_output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Observation space: Discrete(16)\n",
      "Action space: Discrete(4)\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "gym.spaces.discrete.Discrete"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# https://github.com/openai/gym/blob/master/gym/envs/toy_text/frozen_lake.py\n",
    "# https://github.com/openai/gym/blob/master/gym/envs/__init__.py\n",
    "try:\n",
    "    register(\n",
    "        id='FrozenLakeNoSlip-v0',\n",
    "        entry_point='gym.envs.toy_text:FrozenLakeEnv',\n",
    "        kwargs={'map_name' : '4x4', 'is_slippery':False},\n",
    "        max_episode_steps=100,\n",
    "        reward_threshold=0.78, # optimum = .8196\n",
    "    )\n",
    "except:\n",
    "    pass\n",
    "\n",
    "# env_name = \"CartPole-v1\"\n",
    "# env_name = \"MountainCar-v0\"\n",
    "# env_name = \"MountainCarContinuous-v0\"\n",
    "# env_name = \"Acrobot-v1\"\n",
    "# env_name = \"Pendulum-v0\"\n",
    "env_name = \"FrozenLake-v0\"\n",
    "env_name = \"FrozenLakeNoSlip-v0\"\n",
    "env = gym.make(env_name)\n",
    "print(\"Observation space:\", env.observation_space)\n",
    "print(\"Action space:\", env.action_space)\n",
    "type(env.action_space)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Agent():\n",
    "    def __init__(self, env):\n",
    "        self.is_discrete = \\\n",
    "            type(env.action_space) == gym.spaces.discrete.Discrete\n",
    "        \n",
    "        if self.is_discrete:\n",
    "            self.action_size = env.action_space.n\n",
    "            print(\"Action size:\", self.action_size)\n",
    "        else:\n",
    "            self.action_low = env.action_space.low\n",
    "            self.action_high = env.action_space.high\n",
    "            self.action_shape = env.action_space.shape\n",
    "            print(\"Action range:\", self.action_low, self.action_high)\n",
    "        \n",
    "    def get_action(self, state):\n",
    "        if self.is_discrete:\n",
    "            action = random.choice(range(self.action_size))\n",
    "        else:\n",
    "            action = np.random.uniform(self.action_low,\n",
    "                                       self.action_high,\n",
    "                                       self.action_shape)\n",
    "        return action"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Action size: 4\n",
      "State size: 16\n"
     ]
    }
   ],
   "source": [
    "class QAgent(Agent):\n",
    "    def __init__(self, env, discount_rate=0.97, learning_rate=0.01):\n",
    "        super().__init__(env)\n",
    "        self.state_size = env.observation_space.n\n",
    "        print(\"State size:\", self.state_size)\n",
    "        \n",
    "        self.eps = 1.0\n",
    "        self.discount_rate = discount_rate\n",
    "        self.learning_rate = learning_rate\n",
    "        self.build_model()\n",
    "        \n",
    "    def build_model(self):\n",
    "        self.q_table = 1e-4*np.random.random([self.state_size, self.action_size])\n",
    "        \n",
    "    def get_action(self, state):\n",
    "        q_state = self.q_table[state]\n",
    "        action_greedy = np.argmax(q_state)\n",
    "        action_random = super().get_action(state)\n",
    "        return action_random if random.random() < self.eps else action_greedy\n",
    "    \n",
    "    def train(self, experience):\n",
    "        state, action, next_state, reward, done = experience\n",
    "        \n",
    "        q_next = self.q_table[next_state]\n",
    "        q_next = np.zeros([self.action_size]) if done else q_next\n",
    "        q_target = reward + self.discount_rate * np.max(q_next)\n",
    "        \n",
    "        q_update = q_target - self.q_table[state,action]\n",
    "        self.q_table[state,action] += self.learning_rate * q_update\n",
    "        \n",
    "        if done:\n",
    "            self.eps = self.eps * 0.99\n",
    "        \n",
    "agent = QAgent(env)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "s: 15 a: 2\n",
      "Episode: 99, Total reward: 99.0, eps: 0.006570483042414605\n",
      "  (Right)\n",
      "SFFF\n",
      "FHFH\n",
      "FFFH\n",
      "HFF\u001b[41mG\u001b[0m\n",
      "[[2.26657246e-04 5.78706808e-05 1.31003002e-01 4.79979472e-04]\n",
      " [1.04844339e-04 3.54687230e-05 2.55427402e-01 7.48193170e-04]\n",
      " [1.65884052e-04 4.36189268e-01 7.00960632e-05 8.88959016e-04]\n",
      " [1.17590145e-03 5.49928561e-05 6.38101134e-05 3.16729025e-05]\n",
      " [6.15095796e-05 4.13004395e-05 3.01916296e-05 5.90834920e-05]\n",
      " [4.25841839e-05 7.56731911e-06 8.26995681e-05 6.60735347e-05]\n",
      " [6.37688773e-05 6.48470611e-01 7.98525028e-05 3.01763264e-03]\n",
      " [5.30045086e-05 5.07371856e-05 1.99983098e-05 9.32633643e-05]\n",
      " [4.18707042e-05 5.71000704e-05 2.32037264e-04 3.18999734e-05]\n",
      " [4.15603587e-05 6.91116263e-03 2.75917720e-04 2.88792860e-05]\n",
      " [3.17844770e-04 8.42711089e-01 2.42033377e-05 2.93321581e-03]\n",
      " [9.02368275e-05 6.37119718e-05 6.00035959e-05 7.00083118e-05]\n",
      " [7.71263802e-05 1.73696049e-05 4.11883145e-05 3.22260325e-05]\n",
      " [5.22960464e-05 8.34910792e-05 1.07448006e-01 1.17173895e-04]\n",
      " [3.04657414e-03 6.80853272e-02 9.72068773e-01 1.34302488e-02]\n",
      " [1.35322054e-06 5.97967404e-05 5.91493712e-05 5.44460043e-05]]\n"
     ]
    }
   ],
   "source": [
    "total_reward = 0\n",
    "for ep in range(100):\n",
    "    state = env.reset()\n",
    "    done = False\n",
    "    while not done:\n",
    "        action = agent.get_action(state)\n",
    "        next_state, reward, done, info = env.step(action)\n",
    "        agent.train((state,action,next_state,reward,done))\n",
    "        state = next_state\n",
    "        total_reward += reward\n",
    "        \n",
    "        print(\"s:\", state, \"a:\", action)\n",
    "        print(\"Episode: {}, Total reward: {}, eps: {}\".format(ep,total_reward,agent.eps))\n",
    "        env.render()\n",
    "        print(agent.q_table)\n",
    "        time.sleep(0.05)\n",
    "        clear_output(wait=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
